{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RNN language model\n",
    "Loosely based on [Zaremba et al. 2014](https://arxiv.org/abs/1409.2329), this example trains a word-based RNN language model on Mikolov's PTB data with a 10K vocabulary. It uses the `batchSizes` feature of `rnnforw` to process batches that contain sentences of different lengths. The `mb` minibatching function sorts the sentences in a corpus by length and tries to group similarly sized sentences together. For an example that uses fixed-length batches and goes across sentence boundaries, see the [charlm](https://github.com/denizyuret/Knet.jl/blob/master/tutorial/08.charlm.ipynb) notebook. **TODO:** convert to the new RNN interface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training schedule\n",
    "EPOCHS = 10          # number of passes over the training minibatches\n",
    "BATCHSIZE = 64       # sentences per minibatch (passed to mb)\n",
    "# Model shape\n",
    "RNNTYPE = :lstm      # rnnType given to rnninit\n",
    "NUMLAYERS = 1        # numLayers given to rnninit\n",
    "EMBEDSIZE = 128      # word embedding size = RNN input size\n",
    "HIDDENSIZE = 256     # RNN hidden size\n",
    "VOCABSIZE = 10000    # vocabulary words plus one id for the EOS token\n",
    "DROPOUT = 0.5        # dropout probability used during training\n",
    "# Adam optimizer settings\n",
    "LR = 0.001\n",
    "BETA_1 = 0.9\n",
    "BETA_2 = 0.999\n",
    "EPS = 1e-8;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "42068-element Array{Array{UInt16,1},1}\n",
      "3370-element Array{Array{UInt16,1},1}\n",
      "3761-element Array{Array{UInt16,1},1}\n",
      "9999-element Array{String,1}\n"
     ]
    }
   ],
   "source": [
    "# Load data\n",
    "using Knet\n",
    "include(Knet.dir(\"data\",\"mikolovptb.jl\"))\n",
    "trn, val, tst, vocab = mikolovptb()\n",
    "# the vocabulary must leave exactly one id free for the EOS token\n",
    "@assert VOCABSIZE == length(vocab)+1\n",
    "# one summary line per dataset, in the order trn, val, tst, vocab\n",
    "foreach(x -> println(summary(x)), (trn, val, tst, vocab))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "UInt16[0x008e, 0x004e, 0x0036, 0x00fb, 0x0938, 0x0195]\n",
      "[\"no\", \"it\", \"was\", \"n't\", \"black\", \"monday\"]\n"
     ]
    }
   ],
   "source": [
    "# Show the first test sentence, first as word ids and then as the words they index in vocab\n",
    "let sentence = tst[1]\n",
    "    println(sentence)\n",
    "    println(vocab[sentence])\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/latex": [
       "\\begin{verbatim}\n",
       "mikolovptb()\n",
       "\\end{verbatim}\n",
       "Read \\href{https://catalog.ldc.upenn.edu/ldc99t42}{PTB} text from Mikolov's \\href{http://www.fit.vutbr.cz/~imikolov/rnnlm}{RNNLM} toolkit which has been lowercased and reduced to a 10K vocabulary size.  Return a tuple (trn,dev,tst,vocab) where\n",
       "\n",
       "\\begin{verbatim}\n",
       "trn::Vector{Vector{UInt16}}: 42068 sentences, 887521 words\n",
       "dev::Vector{Vector{UInt16}}: 3370 sentences, 70390 words\n",
       "tst::Vector{Vector{UInt16}}: 3761 sentences, 78669 words\n",
       "vocab::Vector{String}: 9999 unique words\n",
       "\\end{verbatim}\n"
      ],
      "text/markdown": [
       "```\n",
       "mikolovptb()\n",
       "```\n",
       "\n",
       "Read [PTB](https://catalog.ldc.upenn.edu/ldc99t42) text from Mikolov's [RNNLM](http://www.fit.vutbr.cz/~imikolov/rnnlm) toolkit which has been lowercased and reduced to a 10K vocabulary size.  Return a tuple (trn,dev,tst,vocab) where\n",
       "\n",
       "```\n",
       "trn::Vector{Vector{UInt16}}: 42068 sentences, 887521 words\n",
       "dev::Vector{Vector{UInt16}}: 3370 sentences, 70390 words\n",
       "tst::Vector{Vector{UInt16}}: 3761 sentences, 78669 words\n",
       "vocab::Vector{String}: 9999 unique words\n",
       "```\n"
      ],
      "text/plain": [
       "\u001b[36m  mikolovptb()\u001b[39m\n",
       "\n",
       "  Read PTB (https://catalog.ldc.upenn.edu/ldc99t42) text from Mikolov's RNNLM (http://www.fit.vutbr.cz/~imikolov/rnnlm) toolkit\n",
       "  which has been lowercased and reduced to a 10K vocabulary size. Return a tuple (trn,dev,tst,vocab) where\n",
       "\n",
       "\u001b[36m  trn::Vector{Vector{UInt16}}: 42068 sentences, 887521 words\u001b[39m\n",
       "\u001b[36m  dev::Vector{Vector{UInt16}}: 3370 sentences, 70390 words\u001b[39m\n",
       "\u001b[36m  tst::Vector{Vector{UInt16}}: 3761 sentences, 78669 words\u001b[39m\n",
       "\u001b[36m  vocab::Vector{String}: 9999 unique words\u001b[39m"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the docstring of the mikolovptb data loader called earlier in this notebook\n",
    "@doc mikolovptb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(658, 53, 59)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Minibatch data into (x,y,b) triples. This is the most complicated part of the code.\n",
    "# mb(sentences, batchsize; eos=VOCABSIZE) sorts the sentences from longest to shortest, cuts them\n",
    "# into groups of at most `batchsize` sentences and returns one (x,y,b) triple per group:\n",
    "#   x,y = [s11,s21,s31,...,s12,s22,...] i.e. all the first words followed by all the second words etc.\n",
    "#         x and y contain the same words shifted: x has an EOS in the beginning, y has an EOS at the end\n",
    "#   b   = [b1,b2,...,bT] i.e. how many sentences have first words, how many have second words etc.\n",
    "# length(x)==length(y)==sum(b) and length(b)==length(s1)+1 (+1 because of EOS), where s1 is the\n",
    "# longest sentence of the group. x, y and b are UInt16 vectors.\n",
    "# `eos` is the id used for the EOS token. An empty sentence contributes the single pair (EOS,EOS)\n",
    "# instead of raising a BoundsError.\n",
    "function mb(sentences,batchsize; eos=VOCABSIZE)\n",
    "    sorted = sort(sentences,by=length,rev=true)\n",
    "    data = []\n",
    "    for i = 1:batchsize:length(sorted)\n",
    "        j = min(i+batchsize-1,length(sorted))\n",
    "        sij = view(sorted,i:j)\n",
    "        T = 1+length(sij[1])   # the longest sentence of the group plus its EOS\n",
    "        x = UInt16[]; y = UInt16[]; b = UInt16[]\n",
    "        for t=1:T\n",
    "            bt = 0\n",
    "            for s in sij\n",
    "                # sentences are sorted by length, so once one has ended all the following have too\n",
    "                t > 1+length(s) && break\n",
    "                push!(x, t == 1 ? eos : s[t-1])           # input: EOS first, then the previous word\n",
    "                push!(y, t <= length(s) ? s[t] : eos)     # target: the current word, EOS after the last one\n",
    "                bt += 1\n",
    "            end\n",
    "            push!(b,bt)\n",
    "        end\n",
    "        push!(data,(x,y,b))\n",
    "    end\n",
    "    return data\n",
    "end\n",
    "\n",
    "# Minibatch the three splits with the same batch size and report the number of batches in each\n",
    "mbtrn, mbval, mbtst = map(corpus -> mb(corpus,BATCHSIZE), (trn,val,tst))\n",
    "length.((mbtrn,mbval,mbtst))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define model\n",
    "# initmodel() creates fresh parameters and returns them as the tuple (r,wr,wx,wy,by):\n",
    "#   r,wr = the two values returned by rnninit for an EMBEDSIZE -> HIDDENSIZE RNN\n",
    "#   wx   = (EMBEDSIZE,VOCABSIZE) embedding matrix, xavier initialized\n",
    "#   wy   = (VOCABSIZE,HIDDENSIZE) output weights, xavier initialized\n",
    "#   by   = (VOCABSIZE,1) output bias, all zeros\n",
    "# The last three are Float32 KnetArrays.\n",
    "function initmodel()\n",
    "    rnn,rnnweights = rnninit(EMBEDSIZE,HIDDENSIZE,rnnType=RNNTYPE,numLayers=NUMLAYERS,dropout=DROPOUT,atype=KnetArray{Float32})\n",
    "    embedding = KnetArray(xavier(Float32,EMBEDSIZE,VOCABSIZE))\n",
    "    outweights = KnetArray(xavier(Float32,VOCABSIZE,HIDDENSIZE))\n",
    "    outbias = KnetArray(zeros(Float32,VOCABSIZE,1))\n",
    "    return rnn,rnnweights,embedding,outweights,outbias\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define loss and its gradient\n",
    "# predict(ws,xs,bs;pdrop=0): ws is the model tuple (r,wr,wx,wy,by), xs the input word ids and\n",
    "# bs the batch sizes per time step. pdrop is the dropout probability applied to the embeddings\n",
    "# and to the RNN output. Returns one column of vocabulary scores per input word.\n",
    "function predict(ws,xs,bs;pdrop=0)\n",
    "    rnn,wrnn,wembed,wout,bout = ws\n",
    "    rnn = value(rnn)\n",
    "    embedded = dropout(wembed[:,xs],pdrop)                 # xs=(ΣBt) embedded=(X,ΣBt)\n",
    "    hidden,_ = rnnforw(rnn,wrnn,embedded,batchSizes=bs)    # hidden=(H,ΣBt)\n",
    "    hidden = dropout(hidden,pdrop)\n",
    "    return wout * hidden .+ bout                           # return=(V,ΣBt)\n",
    "end\n",
    "\n",
    "# loss: nll of the target words y under the scores predict computes for inputs x and batch sizes b;\n",
    "# keyword arguments (e.g. pdrop) are forwarded to predict\n",
    "function loss(w,x,y,b;o...)\n",
    "    scores = predict(w,x,b;o...)\n",
    "    return nll(scores,y)\n",
    "end\n",
    "\n",
    "# lossgradient takes the same arguments as loss and returns (gradient with respect to w, loss value)\n",
    "lossgradient = gradloss(loss);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train and test loops\n",
    "# train: one pass over data, calling update! on the model after every minibatch.\n",
    "# Returns the average of the minibatch losses, each weighted by its number of target words.\n",
    "function train(model,data,optim)\n",
    "    total,count = 0,0\n",
    "    for (x,y,b) in data\n",
    "        grads,batchloss = lossgradient(model,x,y,b;pdrop=DROPOUT)\n",
    "        update!(model, grads, optim)\n",
    "        nwords = length(y)\n",
    "        total += nwords*batchloss\n",
    "        count += nwords\n",
    "    end\n",
    "    return total/count\n",
    "end\n",
    "\n",
    "# test: evaluate the model on data without updating it (loss is called with its default pdrop).\n",
    "# Returns the average of the minibatch losses, each weighted by its number of target words.\n",
    "function test(model,data)\n",
    "    total,count = 0,0\n",
    "    for (x,y,b) in data\n",
    "        nwords = length(y)\n",
    "        total += nwords*loss(model,x,y,b)\n",
    "        count += nwords\n",
    "    end\n",
    "    return total/count\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 27.326092 seconds (21.46 M allocations: 1.116 GiB, 3.64% gc time)\n",
      "  0.900037 seconds (1.05 M allocations: 57.092 MiB, 8.54% gc time)\n",
      "  0.317074 seconds (36.53 k allocations: 3.560 MiB, 6.61% gc time)\n",
      "(1, 661.0763f0, 8.268473f6, 7.955551f6)\n",
      " 13.010777 seconds (1.99 M allocations: 121.218 MiB, 3.83% gc time)\n",
      "  0.323326 seconds (32.78 k allocations: 3.189 MiB, 7.95% gc time)\n",
      "  0.350242 seconds (36.61 k allocations: 3.571 MiB, 8.48% gc time)\n",
      "(2, 561.70776f0, 60723.773f0, 57029.04f0)\n",
      " 13.460838 seconds (1.99 M allocations: 121.222 MiB, 3.70% gc time)\n",
      "  0.275338 seconds (32.83 k allocations: 3.182 MiB, 8.25% gc time)\n",
      "  0.296452 seconds (36.66 k allocations: 3.576 MiB, 8.44% gc time)\n",
      "(3, 402.2791f0, 12841.272f0, 12044.729f0)\n",
      " 13.777386 seconds (1.99 M allocations: 121.357 MiB, 4.60% gc time)\n",
      "  0.320279 seconds (32.88 k allocations: 3.207 MiB, 11.58% gc time)\n",
      "  0.340023 seconds (36.61 k allocations: 3.540 MiB, 9.03% gc time)\n",
      "(4, 334.88797f0, 6216.086f0, 5853.3613f0)\n",
      " 12.735467 seconds (1.99 M allocations: 119.469 MiB, 4.10% gc time)\n",
      "  0.253433 seconds (32.73 k allocations: 3.199 MiB, 3.57% gc time)\n",
      "  0.277777 seconds (36.54 k allocations: 3.579 MiB, 3.83% gc time)\n",
      "(5, 292.5482f0, 3348.629f0, 3137.0308f0)\n",
      " 12.532363 seconds (1.98 M allocations: 119.526 MiB, 2.21% gc time)\n",
      "  0.256721 seconds (32.68 k allocations: 3.182 MiB, 3.22% gc time)\n",
      "  0.281500 seconds (36.47 k allocations: 3.562 MiB, 3.28% gc time)\n",
      "(6, 267.0222f0, 2565.8037f0, 2397.6816f0)\n",
      " 12.499204 seconds (1.99 M allocations: 121.200 MiB, 2.64% gc time)\n",
      "  0.268570 seconds (32.73 k allocations: 3.174 MiB, 4.04% gc time)\n",
      "  0.279359 seconds (36.47 k allocations: 3.544 MiB, 4.16% gc time)\n",
      "(7, 244.29321f0, 3607.902f0, 3366.082f0)\n",
      " 12.890701 seconds (1.99 M allocations: 121.311 MiB, 4.03% gc time)\n",
      "  0.275524 seconds (32.83 k allocations: 3.181 MiB, 7.35% gc time)\n",
      "  0.310087 seconds (36.58 k allocations: 3.564 MiB, 5.54% gc time)\n",
      "(8, 236.4967f0, 2662.8535f0, 2487.5193f0)\n",
      " 13.590224 seconds (1.99 M allocations: 121.336 MiB, 3.89% gc time)\n",
      "  0.274645 seconds (32.83 k allocations: 3.177 MiB, 8.83% gc time)\n",
      "  0.296426 seconds (36.62 k allocations: 3.547 MiB, 9.43% gc time)\n",
      "(9, 219.55128f0, 2022.9292f0, 1862.1813f0)\n",
      " 14.284127 seconds (1.99 M allocations: 120.452 MiB, 6.11% gc time)\n",
      "  0.269549 seconds (32.80 k allocations: 3.187 MiB, 5.90% gc time)\n",
      "  0.284169 seconds (36.53 k allocations: 3.571 MiB, 5.54% gc time)\n",
      "(10, 206.429f0, 1605.2661f0, 1468.9436f0)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"(LSTM(input=128,hidden=256,dropout=0.5), K32(1,1,395264)[-0.07042014⋯], K32(128,10000)[0.006436585⋯], K32(10000,256)[-0.061549857⋯], K32(10000,1)[-0.08249636⋯])\""
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = optim = nothing\n",
    "Knet.gc() # free gpu memory\n",
    "if isfile(\"rnnlm.jld2\")\n",
    "    # A saved model exists: load it and only report its perplexities\n",
    "    model = Knet.load(\"rnnlm.jld2\",\"model\")\n",
    "    @time global j1 = test(model,mbtrn)\n",
    "    @time global j2 = test(model,mbval)\n",
    "    @time global j3 = test(model,mbtst)\n",
    "    println((EPOCHS, exp.((j1,j2,j3))...)); flush(stdout)  # prints perplexity = exp(negative_log_likelihood)\n",
    "else\n",
    "    # No saved model: initialize, train for EPOCHS epochs, then save\n",
    "    model = initmodel()\n",
    "    optim = optimizers(model,Adam,lr=LR,beta1=BETA_1,beta2=BETA_2,eps=EPS)\n",
    "    for epoch=1:EPOCHS\n",
    "        @time global j1 = train(model,mbtrn,optim)  # training pass, returns the average training loss\n",
    "        @time global j2 = test(model,mbval)         # validation loss\n",
    "        @time global j3 = test(model,mbtst)         # test loss\n",
    "        println((epoch, exp.((j1,j2,j3))...)); flush(stdout)  # prints perplexity = exp(negative_log_likelihood)\n",
    "    end\n",
    "    Knet.save(\"rnnlm.jld2\",\"model\",model)\n",
    "end\n",
    "summary(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.5.0",
   "language": "julia",
   "name": "julia-1.5"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.5.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
